Overview of the gadget Package

The gadget package provides a framework for building interpretable, regionally-partitioned decision trees based on local feature effect estimates (such as ICE/PDP or ALE curves). The core workflow — simulate or load data, fit a black-box model, estimate feature effects, fit an explanation tree on those effects, and visualize the resulting regions — is demonstrated below on synthetic and bikeshare data.

The package is modular and extensible: different effect strategies (e.g., partial dependence, accumulated local effects) can be implemented by extending the strategy interface. This design allows users to interpret complex black-box models by partitioning the feature space into regions with distinct, interpretable effect patterns.

Synthetic data

ICE/PDP Method: Get feature effects

# Simulate an XOR-style interaction: the slope of x1 on y flips sign with x3.
set.seed(123)
n <- 5000
x1 <- runif(n, -1, 1)
x2 <- runif(n, -1, 1)
x3 <- runif(n, -1, 1)
y <- ifelse(x3 > 0, 3 * x1, -3 * x1) + x3 + rnorm(n, sd = 0.3)
syn.data <- data.frame(x1, x2, x3, y)

# Fit a random forest on the simulated data and precompute ICE curves
# (20 grid points per feature) for the tree fitting below.
syn.task <- TaskRegr$new("xor", backend = syn.data, target = "y")
syn.learner <- lrn("regr.ranger")
syn.learner$train(syn.task)
syn.predictor <- Predictor$new(
  syn.learner,
  data = syn.data[, c("x1", "x2", "x3")],
  y = syn.data$y
)
syn.effect <- FeatureEffects$new(syn.predictor, grid.size = 20, method = "ice")

ICE/PDP Method: Fit and visualize the explanation tree

# Fit the PD-based explanation tree: at most 4 splits, a 10% relative
# improvement required per split, and no minimum leaf size.
syn.tree.pd <- gadgetTree$new(
  strategy = pdStrategy$new(),
  n.split = 4,
  impr.par = 0.1,
  min.node.size = 1
)
syn.tree.pd$fit(effect = syn.effect, data = syn.data, target.feature.name = "y")
syn.tree.pd$plot_tree_structure()

# Inspect the split table: split feature/value, objective, importances.
syn.esi.pd <- syn.tree.pd$extract_split_info()
print(syn.esi.pd)
##   id depth n.obs node.type split.feature  split.value objective.value    intImp
## 1  1     1  5000      root            x3 0.0001516948      507271.287 0.9867277
## 2  2     2  2555      left          <NA>           NA        3871.077        NA
## 3  3     2  2445     right          <NA>           NA        2861.605        NA
##   intImp.parent intImp.x1  intImp.x2 intImp.x3 split.feature.parent
## 1            NA 0.9904222 0.03722123 0.9975999                 <NA>
## 2     0.9867277        NA         NA        NA                   x3
## 3     0.9867277        NA         NA        NA                   x3
##     split.value.parent objective.value.parent is.final  time
## 1                 <NA>                     NA    FALSE 0.027
## 2 0.000151694752275944               507271.3     TRUE    NA
## 3 0.000151694752275944               507271.3     TRUE    NA
# Plot mean-centered regional effect curves without the individual points.
syn.tree.pd$plot(
  syn.effect,
  syn.data,
  target.feature.name = "y",
  show.plot = TRUE,
  show.point = FALSE,
  mean.center = TRUE
)

### ALE Method

# ALE-based tree: the strategy recomputes ALE per node from the model,
# so fit() takes the trained learner rather than precomputed effects.
syn.tree.ale <- gadgetTree$new(strategy = aleStrategy$new(), n.split = 3)
syn.tree.ale$fit(
  model = syn.learner,
  data = syn.data,
  target.feature.name = "y",
  n.intervals = 10
)
syn.tree.ale$plot_tree_structure()

syn.esi.ale <- syn.tree.ale$extract_split_info()
print(syn.esi.ale)
##   id depth n.obs node.type split.feature  split.value objective.value    intImp
## 1  1     1  5000      root            x3 -0.006593762       6084.6506 0.9248943
## 2  2     2  2536      left          <NA>           NA        223.5526        NA
## 3  3     2  2464     right          <NA>           NA        233.4393        NA
##   intImp.parent intImp.x1  intImp.x2 intImp.x3 split.feature.parent
## 1            NA 0.9153289 0.01917008 0.9845032                 <NA>
## 2     0.9248943        NA         NA        NA                   x3
## 3     0.9248943        NA         NA        NA                   x3
##   split.value.parent objective.value.parent is.final  time
## 1                 NA                     NA    FALSE 0.406
## 2       -0.006593762               6084.651     TRUE    NA
## 3       -0.006593762               6084.651     TRUE    NA
#object_size(syn.tree.ale)

Bikeshare data

ICE/PDP Method: Get feature effects

# Load the bikeshare data and subsample 1000 rows for speed.
library(ISLR2)
data(Bikeshare)
set.seed(123)
# Use nrow() instead of hard-coding the row count 8645; sample(n, k) draws
# identically to sample(1:n, k), so the subsample is unchanged.
bike <- data.table(Bikeshare[sample(nrow(Bikeshare), 1000), ])

# Restrict to three features to keep the example tree small.
# bike.X = bike[, .(day, hr, temp, windspeed, weekday, workingday, hum, season, mnth, holiday, registered, weathersit, atemp, casual)]
bike.X <- bike[, .(hr, temp, workingday)]
bike.y <- bike$bikers
train <- cbind(bike.X, "target" = bike.y)
bike.data <- as.data.frame(train)

# Train a random forest regression on the subsample.
set.seed(123)
bike.task <- TaskRegr$new(id = "bike", backend = bike.data, target = "target")
bike.learner <- lrn("regr.ranger")
bike.learner$train(bike.task)

# Pull features/target back from the task so types match what the learner saw.
bike.X <- bike.task$data(cols = bike.task$feature_names)
bike.y <- bike.task$data(cols = bike.task$target_names)[[1]]

bike.predictor <- Predictor$new(model = bike.learner, data = bike.X, y = bike.y)

# Precompute ICE curves (20 grid points) for all features.
effect.all <- FeatureEffects$new(bike.predictor, method = "ice",
  grid.size = 20)

ICE/PDP Method: Fit and visualize the explanation tree

# PD-based explanation tree on the bikeshare model, up to 4 splits with
# default stopping parameters.
bike.tree.pd <- gadgetTree$new(strategy = pdStrategy$new(), n.split = 4)
bike.tree.pd$fit(
  effect = effect.all,
  data = bike.data,
  target.feature.name = "target"
)
bike.tree.pd$plot_tree_structure()

# Summarize the splits that were found.
bike.esi.pd <- bike.tree.pd$extract_split_info()
print(bike.esi.pd)
##   id depth n.obs node.type split.feature split.value objective.value    intImp
## 1  1     1  1000      root    workingday        0.50      28009105.5 0.5420253
## 2  2     2   316      left          temp        0.45       4364418.6 0.1361772
## 3  3     2   684     right          temp        0.51       8463044.3 0.2528883
## 4  4     3   148      left          <NA>          NA        264791.5        NA
## 5  5     3   168     right          <NA>          NA        285424.9        NA
## 6  6     3   345      left          <NA>          NA        727370.8        NA
## 7  7     3   339     right          <NA>          NA        652498.2        NA
##   intImp.parent intImp.hr intImp.temp intImp.workingday split.feature.parent
## 1            NA 0.4742788   0.1014083      1.000000e+00                 <NA>
## 2     0.5420253 0.1461560   0.2846363      1.752207e-36           workingday
## 3     0.5420253 0.2798375   0.5141399      6.246519e-36           workingday
## 4     0.1361772        NA          NA                NA                 temp
## 5     0.1361772        NA          NA                NA                 temp
## 6     0.2528883        NA          NA                NA                 temp
## 7     0.2528883        NA          NA                NA                 temp
##   split.value.parent objective.value.parent is.final  time
## 1               <NA>                     NA    FALSE 0.004
## 2                0.5               28009105    FALSE 0.002
## 3                0.5               28009105    FALSE 0.003
## 4               0.45                4364419     TRUE    NA
## 5               0.45                4364419     TRUE    NA
## 6               0.51                8463044     TRUE    NA
## 7               0.51                8463044     TRUE    NA
# Plot raw (not mean-centered) regional effects with ICE points shown,
# restricted to the non-root nodes and to the hr and temp features.
bike.tree.pd$plot(
  effect.all,
  bike.data,
  target.feature.name = "target",
  show.plot = TRUE,
  show.point = TRUE,
  mean.center = FALSE,
  depth = c(2, 3),
  node.id = 2:7,
  features = c("hr", "temp")
)

### ALE Method

# ALE-based tree on the bikeshare model (3 splits, 10 ALE intervals).
bike.tree.ale <- gadgetTree$new(strategy = aleStrategy$new(), n.split = 3)
bike.tree.ale$fit(
  model = bike.learner,
  data = bike.data,
  target.feature.name = "target",
  n.intervals = 10
)
bike.tree.ale$plot_tree_structure()

bike.esi.ale <- bike.tree.ale$extract_split_info()
print(bike.esi.ale)
##   id depth n.obs node.type split.feature split.value objective.value    intImp
## 1  1     1  1000      root    workingday        0.50      1835142.68 0.1969098
## 2  2     2   316      left          <NA>          NA        54221.14        NA
## 3  3     2   684     right          temp        0.83      1419563.98 1.0603092
## 4  6     3   666      left          <NA>          NA       191359.09        NA
## 5  7     3    18     right          <NA>          NA      -717613.77        NA
##   intImp.parent   intImp.hr intImp.temp intImp.workingday split.feature.parent
## 1            NA  0.68140007   0.2024446      1.332139e-15                 <NA>
## 2     0.1969098          NA          NA                NA           workingday
## 3     0.1969098 -0.09142195  -0.2841276      1.654368e+00           workingday
## 4     1.0603092          NA          NA                NA                 temp
## 5     1.0603092          NA          NA                NA                 temp
##   split.value.parent objective.value.parent is.final  time
## 1               <NA>                     NA    FALSE 0.069
## 2                0.5                1835143     TRUE    NA
## 3                0.5                1835143    FALSE 0.048
## 4               0.83                1419564     TRUE    NA
## 5               0.83                1419564     TRUE    NA
# Size of the extracted split-info table itself (not the fitted tree).
object_size(bike.esi.ale)
## 4.02 kB

Speed

# Compare the per-depth split-time distributions across the four trees.
split.times <- list(
  "Syn.PD"   = syn.esi.pd,
  "Syn.ALE"  = syn.esi.ale,
  "Bike.PD"  = bike.esi.pd,
  "Bike.ALE" = bike.esi.ale
)
for (nm in names(split.times)) {
  boxplot(time ~ depth, data = split.times[[nm]],
    main = paste("Distribution of split time per depth -", nm))
}

gadgetTree fit benchmark

Data generation

# Benchmark setup: raise the future globals limit to 4 GB and force
# sequential execution so timings are not confounded by parallel workers.
set.seed(1)
options(future.globals.maxSize = 4 * 1024 * 1024^2) # 4GB
plan(sequential)

# Simulate an n-row data set with 5 features, fit a ranger forest, and
# return both the data and precomputed ICE curves for benchmarking.
# The signal combines a main effect of x2 with x2:x3 and x1:x2 interactions.
datagen_p5 <- function(n, seed = 1) {
  set.seed(seed)
  # Two rounded numeric features, one binary factor, two binary numerics.
  x1 <- round(runif(n, -1, 1), 1)
  x2 <- round(runif(n, -1, 1), 3)
  x3 <- as.factor(sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5)))
  x4 <- sample(c(0, 1), size = n, replace = TRUE, prob = c(0.7, 0.3))
  x5 <- sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5))
  dat <- data.frame(x1, x2, x3, x4, x5)
  # Signal plus Gaussian noise scaled to 10% of the signal's sd.
  signal <- 0.2 * x1 - 8 * x2 + ifelse(x3 == 0, 16 * x2, 0) +
    ifelse(x1 > 0, 8 * x2, 0)
  dat$y <- signal + rnorm(n, 0, 0.1 * sd(signal))
  # Fit the forest and wrap it in an iml Predictor to get ICE curves.
  features <- dat[, setdiff(colnames(dat), "y")]
  forest <- ranger(y ~ ., data = dat, num.trees = 100)
  predict_fun <- function(model, newdata) predict(model, newdata)$predictions
  predictor <- Predictor$new(forest, data = features, y = dat$y,
    predict.function = predict_fun)
  eff <- FeatureEffects$new(predictor, method = "ice", grid.size = 20)
  list(dat = dat, eff = eff)
}

# Like datagen_p5, but with five additional numeric features (x6-x10);
# only x8 enters the signal, the rest are distractors for the splitter.
datagen_p10 <- function(n, seed = 1) {
  set.seed(seed)
  x1 <- round(runif(n, -1, 1), 1)
  x2 <- round(runif(n, -1, 1), 3)
  x3 <- as.factor(sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5)))
  x4 <- sample(c(0, 1), size = n, replace = TRUE, prob = c(0.7, 0.3))
  x5 <- sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5))
  # Numeric features on widely different scales and rounding precisions.
  x6 <- rnorm(n, mean = 1, sd = 5)
  x7 <- round(rnorm(n, mean = 10, sd = 10), 2)
  x8 <- round(rnorm(n, mean = 100, sd = 15), 4)
  x9 <- round(rnorm(n, mean = 1000, sd = 20), 7)
  x10 <- rnorm(n, mean = 10000, sd = 25)
  dat <- data.frame(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10)
  # Signal (now including a main effect of x8) plus 10%-sd Gaussian noise.
  signal <- 0.2 * x1 - 8 * x2 + ifelse(x3 == 0, 16 * x2, 0) +
    ifelse(x1 > 0, 8 * x2, 0) + 2 * x8
  dat$y <- signal + rnorm(n, 0, 0.1 * sd(signal))
  # Fit the forest and wrap it in an iml Predictor to get ICE curves.
  features <- dat[, setdiff(colnames(dat), "y")]
  forest <- ranger(y ~ ., data = dat, num.trees = 100)
  predict_fun <- function(model, newdata) predict(model, newdata)$predictions
  predictor <- Predictor$new(forest, data = features, y = dat$y,
    predict.function = predict_fun)
  eff <- FeatureEffects$new(predictor, method = "ice", grid.size = 20)
  list(dat = dat, eff = eff)
}

# Like datagen_p10, but padded with 10 pure-noise features (noise1-noise10)
# to stress the splitter with 20 candidate features in total.
datagen_p20 <- function(n, seed = 1) {
  set.seed(seed)
  x1 <- round(runif(n, -1, 1), 1)
  x2 <- round(runif(n, -1, 1), 3)
  x3 <- as.factor(sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5)))
  x4 <- sample(c(0, 1), size = n, replace = TRUE, prob = c(0.7, 0.3))
  x5 <- sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5))
  # Numeric features on widely different scales and rounding precisions.
  x6 <- rnorm(n, mean = 1, sd = 5)
  x7 <- round(rnorm(n, mean = 10, sd = 10), 2)
  x8 <- round(rnorm(n, mean = 100, sd = 15), 4)
  x9 <- round(rnorm(n, mean = 1000, sd = 20), 7)
  x10 <- rnorm(n, mean = 10000, sd = 25)
  # Ten standard-normal noise columns with no relation to y.
  noise <- replicate(10, rnorm(n), simplify = FALSE)
  names(noise) <- paste0("noise", 1:10)
  dat <- data.frame(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, noise)
  # Same signal as datagen_p10 plus 10%-sd Gaussian noise.
  signal <- 0.2 * x1 - 8 * x2 + ifelse(x3 == 0, 16 * x2, 0) +
    ifelse(x1 > 0, 8 * x2, 0) + 2 * x8
  dat$y <- signal + rnorm(n, 0, 0.1 * sd(signal))
  # Fit the forest and wrap it in an iml Predictor to get ICE curves.
  features <- dat[, setdiff(colnames(dat), "y")]
  forest <- ranger(y ~ ., data = dat, num.trees = 100)
  predict_fun <- function(model, newdata) predict(model, newdata)$predictions
  predictor <- Predictor$new(forest, data = features, y = dat$y,
    predict.function = predict_fun)
  eff <- FeatureEffects$new(predictor, method = "ice", grid.size = 20)
  list(dat = dat, eff = eff)
}

Run experiments

# Grid of sample sizes and feature counts for the fit benchmark.
n_list <- c(1000, 5000, 10000)
p_list <- c(5, 10, 20)

bench_results <- list()
# One row per (n, p) cell: fitted-tree size plus memory usage around the fit.
tree_sizes <- data.frame(
  n = integer(), p = integer(), tree_size_MB = numeric(),
  mem_before_MB = numeric(), mem_after_MB = numeric(),
  mem_increase_MB = numeric()
)
# Initial memory cleanup
gc()
##            used  (Mb) gc trigger  (Mb) limit (Mb) max used  (Mb)
## Ncells  2899530 154.9    5215075 278.6         NA  5215075 278.6
## Vcells 14070167 107.4   40000535 305.2      32768 39999922 305.2
# Baseline memory state. NOTE(review): gc()'s "used" column for Vcells is a
# count of 8-byte vector cells, not megabytes.
initial_mem = gc()["Vcells", "used"]
# Convert a gc() Vcell count to megabytes (each vector cell is 8 bytes).
# BUG FIX: previously the raw cell counts were stored and printed as "MB",
# producing absurd figures such as "Memory used: 522373.00 MB".
vcells_to_mb <- function(cells) cells * 8 / 1024^2

for (n in n_list) {
  for (p in p_list) {
    cat(sprintf("Running: n = %d, p = %d\n", n, p))
    # Clean memory and record starting state
    gc()
    mem_before <- vcells_to_mb(gc()["Vcells", "used"])
    # Data generation
    if (p == 5) {
      sim <- datagen_p5(n)
    } else if (p == 10) {
      sim <- datagen_p10(n)
    } else if (p == 20) {
      sim <- datagen_p20(n)
    }
    # Clean memory after data generation
    gc()
    # Create and fit tree
    tree <- gadgetTree$new(strategy = pdStrategy$new(), n.split = 10)
    tree$fit(effect = sim$eff, data = sim$dat, target.feature.name = "y")
    # Clean memory after tree fitting
    gc()
    # Calculate tree size and memory usage (both now in real MB)
    tree_size_MB <- as.numeric(pryr::object_size(tree)) / 1024^2
    mem_after <- vcells_to_mb(gc()["Vcells", "used"])
    mem_increase <- mem_after - mem_before
    # Record results
    tree_sizes <- rbind(tree_sizes, data.frame(
      n = n,
      p = p,
      tree_size_MB = tree_size_MB,
      mem_before_MB = mem_before,
      mem_after_MB = mem_after,
      mem_increase_MB = mem_increase
    ))
    # Clean up tree object
    rm(tree)
    gc()
    # Benchmark with memory monitoring: 5 fits per cell; the gc() calls
    # inside the timed expression keep memory pressure comparable per run.
    res <- bench::mark(
      fit = {
        gc()
        tree <- gadgetTree$new(strategy = pdStrategy$new(), n.split = 10)
        tree$fit(effect = sim$eff, data = sim$dat, target.feature.name = "y")
        gc()
      },
      iterations = 5
    )
    res$n <- n
    res$p <- p
    bench_results[[paste0("n", n, "_p", p)]] <- res
    # Clean up sim data
    rm(sim)
    gc()
    cat(sprintf("Memory used: %.2f MB, Tree size: %.2f MB\n",
      mem_increase, tree_size_MB))
  }
}
## Running: n = 1000, p = 5
## Memory used: 522373.00 MB, Tree size: 3.49 MB
## Running: n = 1000, p = 10
## Memory used: 1131087.00 MB, Tree size: 4.09 MB
## Running: n = 1000, p = 20
## Memory used: 2054541.00 MB, Tree size: 4.12 MB
## Running: n = 5000, p = 5
## Memory used: 2302209.00 MB, Tree size: 3.54 MB
## Running: n = 5000, p = 10
## Memory used: 5585787.00 MB, Tree size: 5.58 MB
## Running: n = 5000, p = 20
## Memory used: 10170724.00 MB, Tree size: 5.93 MB
## Running: n = 10000, p = 5
## Memory used: 4357172.00 MB, Tree size: 3.60 MB
## Running: n = 10000, p = 10
## Memory used: 11148419.00 MB, Tree size: 7.05 MB
## Running: n = 10000, p = 20
## Memory used: 20286252.00 MB, Tree size: 6.07 MB
# Final memory cleanup
gc()
##            used  (Mb) gc trigger  (Mb) limit (Mb)  max used  (Mb)
## Ncells  3653258 195.2    6298091 336.4         NA   6298091 336.4
## Vcells 53709992 409.8  120246076 917.5      32768 120245385 917.4
# Report net memory growth over the whole experiment in real megabytes.
# BUG FIX: gc()["Vcells", "used"] counts 8-byte cells, not MB, so the
# difference must be scaled by 8 / 1024^2 before being labeled "MB".
final_mem <- gc()["Vcells", "used"]
cat(sprintf("Total memory increase: %.2f MB\n",
  (final_mem - initial_mem) * 8 / 1024^2))
## Total memory increase: 39677448.00 MB

Collect and visualize results

# Flatten the per-cell benchmark results into one long data frame with one
# row per recorded timing.
bench_all <- do.call(rbind, bench_results)
# ROBUSTNESS FIX: replicate n and p once per actual timing instead of
# assuming exactly 5 (rep(..., each = 5) silently misaligns rows if
# bench::mark ever records a different number of iterations).
iter_counts <- lengths(bench_all$time)
bench_long <- data.frame(
  n = rep(bench_all$n, times = iter_counts),
  p = rep(bench_all$p, times = iter_counts),
  time_ms = as.numeric(unlist(bench_all$time)) * 1000
)

# One box per (n, p) cell, with the individual timings jittered on top.
ggplot(bench_long, aes(x = factor(n), y = time_ms, color = factor(p), group = p)) +
  geom_boxplot(aes(group = interaction(n, p))) +
  geom_jitter(width = 0.2, alpha = 0.5) +
  labs(x = "Sample Size (n)", y = "Fit Time (ms)", color = "Feature Number (p)",
       title = "gadgetTree$fit(n.split = 10) Benchmark: Varying n and p") +
  theme_minimal()

# Print the per-cell size and memory summary table.
tree_sizes
##       n  p tree_size_MB mem_before_MB mem_after_MB mem_increase_MB
## 1  1000  5     3.492287      14034366     14556739          522373
## 2  1000 10     4.094444      14146699     15277786         1131087
## 3  1000 20     4.116432      16188294     18242835         2054541
## 4  5000  5     3.538010      19804925     22107134         2302209
## 5  5000 10     5.582108      19894630     25480417         5585787
## 6  5000 20     5.927483      24622692     34793416        10170724
## 7 10000  5     3.595230      35806504     40163676         4357172
## 8 10000 10     7.049728      35852011     47000430        11148419
## 9 10000 20     6.073471      43342535     63628787        20286252

Simplified Risk Calculation

\[ \begin{aligned} SSE &= \sum_{i=1}^n(y_i-\bar{y})^2\\ &= \sum_{i=1}^ny_i^2-2\bar{y}\sum_{i=1}^ny_i+\sum_{i=1}^n\bar{y}^2\\ &= \sum_{i=1}^ny_i^2-2\bar{y}n\bar{y}+n\bar{y}^2\\ &= \sum_{i=1}^ny_i^2 - n\bar{y}^2\\ &= \sum_{i=1}^ny_i^2-n\left(\frac{\sum_{i=1}^ny_i}{n}\right)^2\\ &= \sum_{i=1}^ny_i^2-\frac{1}{n}\left(\sum_{i=1}^ny_i\right)^2\\ &= SS -\frac{S^2}{n} \end{aligned} \] \[ \begin{aligned} SSE_{Reduction} &= SSE_{parent}-SSE_{left}-SSE_{right}\\ &= SS_{parent} -\frac{S_{parent}^2}{n_{parent}}-\left(SS_{left} -\frac{S_{left}^2}{n_{left}}\right)-\left(SS_{right} -\frac{S_{right}^2}{n_{right}}\right) \end{aligned} \] Since \[ n_{parent} = n_{left} + n_{right},\qquad SS_{parent} = SS_{left} + SS_{right} \]

Then \[ \begin{aligned} SSE_{Reduction} &= -\frac{S_{parent}^2}{n_{parent}} +\frac{S_{left}^2}{n_{left}} +\frac{S_{right}^2}{n_{right}}\\ \max(SSE_{Reduction}) &= \max\left(\frac{S_{left}^2}{n_{left}} +\frac{S_{right}^2}{n_{right}}\right)=\min\left(-\frac{S_{left}^2}{n_{left}} -\frac{S_{right}^2}{n_{right}}\right) \end{aligned} \]